import tensorflow as tf
import os
import multiprocessing as mp
import cv2

# Geometry of the stored frames. parse_record reshapes the raw bytes of every
# example to these shapes, so all converted images/labels must match them.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH = 360, 480, 3
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)
# Labels hold a single channel: one value per pixel.
LABEL_SHAPE = IMAGE_SHAPE[:2] + (1,)
# Flattened element counts of one label / one image.
LABEL_SIZE = IMAGE_HEIGHT * IMAGE_WIDTH
IMAGE_SIZE = LABEL_SIZE * IMAGE_DEPTH

# Number of training examples per epoch; CamVidInputs sizes its shuffle buffer from it.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 367

def _bytes_feature(value):
    """Wrap one bytes value in a ``tf.train.Feature``.

    :param value: a single bytes object (for example the output of
        ``ndarray.tobytes()``); it becomes the only entry of the list.
    :return: a ``tf.train.Feature`` whose ``bytes_list`` holds ``value``.
    """
    single_entry_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_entry_list)


def _int64_feature(value):
    """Returns an int64_list feature from a bool / enum / int / uint, or an iterable of them.

    :param value: a single integer-like scalar, or an iterable of them.
    :return: a ``tf.train.Feature`` with its ``int64_list`` populated.
    """
    try:
        iter(value)
    except TypeError:
        # Scalars are not iterable, but Int64List needs a sequence: wrap them,
        # as _bytes_feature does. Iterables are passed through unchanged.
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _float_feature(value):
    """Returns a float_list feature from a float / double, or an iterable of them.

    :param value: a single float-like scalar, or an iterable of them.
    :return: a ``tf.train.Feature`` with its ``float_list`` populated.
    """
    try:
        iter(value)
    except TypeError:
        # Scalars are not iterable, but FloatList needs a sequence: wrap them,
        # as _bytes_feature does. Iterables are passed through unchanged.
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def image2tfrecords(data_dir, label_dir, tf_record_path):
    """Convert paired image/label files into a single TFRecord file.

    Every file in ``data_dir`` is paired with the file of the same name in
    ``label_dir``. Both are read with ``cv2.IMREAD_UNCHANGED`` and stored as raw
    bytes under the ``'image_array'`` and ``'label_array'`` keys.

        :param data_dir: directory containing the input images.
        :param label_dir: directory containing the label images, named like the inputs.
        :param tf_record_path: path of the output TFRecord file.
        :raises IOError: if an image or its label cannot be read by OpenCV.

        Note: This method should be used in the initial step to create the tfrecords dataset.
        Only raw bytes are stored (no shape or dtype), and parse_record decodes
        them as uint8 reshaped to IMAGE_SHAPE / LABEL_SHAPE, so the inputs
        presumably must be 8-bit and of exactly that size.
    """
    # The context manager closes the writer even if a read fails half-way.
    with tf.python_io.TFRecordWriter(tf_record_path) as tf_writer:
        # os.listdir order is arbitrary; sorting makes the record order reproducible.
        for image_name in sorted(os.listdir(data_dir)):
            image = _read_image_unchanged(os.path.join(data_dir, image_name))
            label = _read_image_unchanged(os.path.join(label_dir, image_name))

            example = tf.train.Example(features=tf.train.Features(feature={
                'image_array': _bytes_feature(image.tobytes()),
                'label_array': _bytes_feature(label.tobytes()),
            }))
            tf_writer.write(example.SerializeToString())


def _read_image_unchanged(path):
    """Load ``path`` with OpenCV as stored on disk, raising IOError if it cannot be read."""
    array = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # cv2.imread reports failure by returning None rather than raising.
    if array is None:
        raise IOError("could not read image file: {}".format(path))
    return array


def parse_record(example):
    """Decode one serialized example in the format written by ``image2tfrecords``.

    :params example: a serialized ``tf.train.Example`` holding the
        ``'image_array'`` and ``'label_array'`` byte strings.

    :return: ``(image, label)``. The image bytes are decoded as uint8, reshaped
        to ``IMAGE_SHAPE`` and cast to float32. The label bytes are decoded as
        uint8, reshaped to ``LABEL_SHAPE`` and cast to int64.
    """
    schema = {
        key: tf.FixedLenFeature([], tf.string)
        for key in ('image_array', 'label_array')
    }
    fields = tf.parse_single_example(example, features=schema)

    image_bytes = tf.decode_raw(fields['image_array'], tf.uint8)
    label_bytes = tf.decode_raw(fields['label_array'], tf.uint8)

    image = tf.cast(tf.reshape(image_bytes, IMAGE_SHAPE), dtype=tf.float32)
    label = tf.cast(tf.reshape(label_bytes, LABEL_SHAPE), dtype=tf.int64)
    return image, label


def CamVidInputs(tfrecord_path, batch_size, shuffle):
    """Create an endlessly repeating, batched dataset of ``(image, label)`` pairs.

    :params tfrecord_path: the path of the TFRecord file to read.
    :params batch_size: number of examples per batch. The final partial batch
        of each pass is dropped (``drop_remainder=True``).
    :params shuffle: if true, shuffle the serialized records before decoding,
        using a buffer of 40% of ``NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN``.

    :return: a ``tf.data.Dataset`` of batched ``(image, label)`` tensors as
        produced by ``parse_record``. It repeats indefinitely.
    """
    dataset = tf.data.TFRecordDataset([tfrecord_path])
    if shuffle:
        min_fraction_of_examples_in_queue = 0.4
        buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)
        # Shuffling serialized records (before parse_record) avoids buffering decoded tensors.
        dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.map(parse_record, num_parallel_calls=mp.cpu_count())
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat()
    # Prefetch as the last stage. Before repeat() the prefetch buffer would be
    # recreated empty at every epoch boundary, stalling the consumer each epoch.
    dataset = dataset.prefetch(64)
    return dataset


def initial(record_path, train_path_list, val_path_list, test_path_list):
    """Convert the CamVid image dataset to TFRecord files when first training.

    Each ``*_path_list`` is indexed as ``[data_dir, label_dir, tf_record_path]``.
    A split is converted only when its ``tf_record_path`` does not exist yet.

        :params record_path: directory for the tfrecord dataset; created if missing.
        :params train_path_list: the train data info
        :params val_path_list: the val data info
        :params test_path_list: the test data info
    """
    if not os.path.exists(record_path):
        print("creating tfrecord dataset path......")
        os.makedirs(record_path)

    for split_name, path_list in (('train', train_path_list),
                                  ('val', val_path_list),
                                  ('test', test_path_list)):
        _convert_split_if_missing(split_name, path_list)


def _convert_split_if_missing(split_name, path_list):
    """Write one split's TFRecord file via image2tfrecords unless it already exists."""
    data_dir, label_dir, tf_record_path = path_list[0], path_list[1], path_list[2]
    if os.path.exists(tf_record_path):
        return
    print("creating {} tfrecord dataset, it will take a few minute......".format(split_name))
    # Write to a temporary sibling and rename only on success. A failed or
    # interrupted conversion then leaves nothing at tf_record_path, so the next
    # run retries instead of silently skipping an empty or partial file.
    tmp_path = tf_record_path + '.tmp'
    try:
        image2tfrecords(data_dir=data_dir, label_dir=label_dir, tf_record_path=tmp_path)
        os.rename(tmp_path, tf_record_path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='manual to this script')
    # All three paths are mandatory. With a None default, os.listdir(None)
    # lists the current directory instead of failing, and the writer is handed None.
    parser.add_argument('--data_path', type=str, required=True,
                        help='directory containing the input images')
    parser.add_argument('--label_dir', type=str, required=True,
                        help='directory containing the label images (same file names as the inputs)')
    parser.add_argument('--tf_record_path', type=str, required=True,
                        help='path of the TFRecord file to write')
    args = parser.parse_args()
    image2tfrecords(data_dir=args.data_path, label_dir=args.label_dir, tf_record_path=args.tf_record_path)
